Listing Mortality Prediction

Author

Michael D. Porter

Published

July 5, 2024

Data

  1. Load training and test data
Code
# Load pre-split training and test sets from RDS files.
# NOTE(review): `dir_data` and `read_rds()` (readr) are assumed to be defined
# upstream of this chunk.
data_train = read_rds(file.path(dir_data, "model_data_train.rds")) 
data_test = read_rds(file.path(dir_data, "model_data_test.rds")) 
  1. Clean Data
    • Remove outliers in Creatinine (and eGFR).
    • Remove outliers in median_refusals
    • Impute missing height, weight, and BMI using Age and Gender
    • Remove lengthy Functional Status descriptions
Code
#: table of sizes from data (by AGE and GENDER)
#: table of sizes from data (by AGE and GENDER)
# Median height/weight per (AGE, GENDER) cell, pooled over train AND test.
# NOTE(review): only covariate medians (not the outcome) are pooled from the
# test split -- presumably acceptable leakage for imputation; confirm.
impute_size_data = bind_rows(data_train, data_test) %>% 
  group_by(AGE, GENDER) %>% 
  summarize(
    WEIGHT_KG = median(WEIGHT_KG, na.rm=TRUE), 
    HEIGHT_CM = median(HEIGHT_CM, na.rm=TRUE), 
    .groups = "drop"
  )
#: function to impute missing height and weight (using AGE, GENDER)
#: Impute missing height or weight from the (AGE, GENDER) median table.
#
# @param var    which size variable to return: "HEIGHT_CM" or "WEIGHT_KG"
#               (abbreviations allowed via match.arg()).
# @param AGE    vector of ages, aligned element-wise with GENDER.
# @param GENDER vector of genders, aligned element-wise with AGE.
# @return numeric vector of median sizes, one per input element; NA when the
#         (AGE, GENDER) combination is absent from `impute_size_data`.
impute_size <- function(var, AGE, GENDER){
  var = match.arg(var, c("HEIGHT_CM", "WEIGHT_KG"))
  X = tibble(AGE, GENDER) %>% 
    left_join(impute_size_data, by = c("AGE", "GENDER"))
  # match.arg() guarantees `var` is one of the two column names, so extract
  # the column directly instead of branching on each variable.
  X[[var]]
}

# outliers in CREAT and median refusals
# Clean a listing data set: cap outliers, impute body-size measures, and
# simplify the functional-status labels.
#
# All steps run in a single mutate(); dplyr evaluates expressions in order,
# so each line sees the results of the lines above it (eGFR is rescaled
# using the *uncapped* creatinine, and BMI/BSA use the imputed sizes).
clean_data <- function(df){
  df %>% 
    mutate(
      # outliers in CREAT: rescale eGFR, then cap creatinine at 8
      eGFR = ifelse(MOST_RCNT_CREAT > 8, eGFR * MOST_RCNT_CREAT / 8, eGFR),
      MOST_RCNT_CREAT = pmin(MOST_RCNT_CREAT, 8),
      # outliers in median refusals
      median_refusals = pmin(median_refusals, 20),
      # impute height and weight, then the derived BMI / BSA when missing
      HEIGHT_CM = coalesce(HEIGHT_CM, impute_size("HEIGHT_CM", AGE, GENDER)),
      WEIGHT_KG = coalesce(WEIGHT_KG, impute_size("WEIGHT_KG", AGE, GENDER)),
      BMI = coalesce(BMI, WEIGHT_KG / (HEIGHT_CM/100)^2),
      BSA = coalesce(BSA, sqrt(HEIGHT_CM * WEIGHT_KG / 3600)),
      # drop lengthy Functional Status descriptions; keep only the leading %
      FUNC_STAT_CAND_REG = str_replace(FUNC_STAT_CAND_REG, "(\\d+%).+", "\\1")
    )
}

Update data.

Code
# Apply the cleaning pipeline to both splits.
data_train = data_train %>% clean_data()
data_test = data_test %>% clean_data()
  1. Create 10-fold cv of training data.
Code
# Fixed seed so the fold assignment is reproducible; folds are stratified
# on the outcome to balance event rates across folds.
set.seed(2024)
cv_folds = rsample::vfold_cv(data_train, v = 10, strata = outcome)

Predictive Modeling

Create baseline preprocessing recipe and set predictor variables. All models start with this recipe.

  • Sets predictors and outcome variables
  • Converts Diabetes to {Yes, No}
  • Converts Functional Status to numeric, adds indicator for baby and unknown
  • Cleans eGFR: removes outliers, code based on kidney risk, binary for eGFR < 60
  • Creates binary for Albumin < 3
Code
library(tidymodels)

# Baseline preprocessing recipe shared by every model in this document.
# head(data_train) is sufficient here: recipe() only needs column names/types.
base_rec = 
  #: Set formula
  recipe(outcome ~ ., data = head(data_train)) %>%
  # step_naomit(outcome) %>%       # remove rows with missing outcome
  # outcome as a factor with event level "1" first; skip=TRUE so baking new
  # data that lacks an outcome column does not error
  step_mutate(outcome = factor(outcome, levels = c(1, 0)), skip=TRUE) %>% 
  #: Remove variables from all models
  step_rm(WL_ID_CODE) %>% 
  step_rm(
    matches("DONCRIT_.+_AGE"), 
    matches("DONCRIT_.+_HGT"),  
    matches("DONCRIT_.+_WGT"), 
    matches("DONCRIT_.+_MILE"), 
    matches("DONCRIT_"), # remove all DONCRIT variables. They are primarily
    # recorded for a few UNOS REGIONS. 
    # --- not clinical
    CITIZENSHIP, 
    HEMODYNAMICS_CO, # lots of missing
    CEREB_VASC, 
    # ---
    CAND_DIAG_LISTING, # use CAND_DIAG instead
    CAND_DIAG_CODE,    # use CAND_DIAG instead 
    MOST_RCNT_CREAT,   # use eGFR instead
    VAD_DEVICE_TY_TCR, # not enough info
    WL_DT,             # use LIST_YR for temporal information
    #------------------------------------- Substitutes
    # LC_effect,
    median_wait_days,       # use LC_effect instead
    median_wait_days_1A,    # use LC_effect instead
    median_wait_days_STATUS,# use LC_effect instead
    median_refusals_old, # use median_refusals instead
    mean_refusals,       # use median_refusals instead
    p_refusals,          # use median_refusals instead
    #-------------------------------------
    # LISTING_CTR_CODE, # Let individual models choose
    REGION,           # Some regions only have 1-2 centers
    ## Requested by Dr. Haregu
    LIFE_SUPPORT_OTHER,
    PGE_TCR,
    LIST_YR,
  ) %>% 
  #: Additional cleaning
  step_mutate(
    # convert Diabetes to Yes = 1, No = 0
    DIAB = case_match(DIAB, 
                      "None" ~ 0L, 
                      "Unknown" ~ 0L, 
                      .default = 1L)
  ) %>% 
  #: Convert Functional status to numeric; add indicators for missing and baby
  step_mutate(
    # e.g. "50% - ..." -> 50; missing/non-percent labels become 0
    FUNC_STAT_NUM = str_extract(FUNC_STAT_CAND_REG, "(\\d+)%", group = 1) %>%
      as.numeric() %>% coalesce(0),
    FUNC_STAT_UNKNOWN = ifelse(FUNC_STAT_CAND_REG == "Unknown", 1L, 0L),
    CAND_UNDER_1 = ifelse(FUNC_STAT_CAND_REG == "Not Applicable (patient < 1 year old)", 1L, 0L),
  ) %>% 
  step_rm(FUNC_STAT_CAND_REG) %>% # remove the original variable
  #: Cutoffs for eGFR (Kidney Disease) and Albumin (Nutrition)
  step_mutate(
    eGFR = pmin(eGFR, 250),  # fix outliers for non-tree models
    # ordinal coding following chronic kidney disease stages
    eGFR_CODED = case_when(
      eGFR > 120 ~ 0,
      eGFR >= 90 ~ 1, 
      eGFR >= 60 ~ 2,
      eGFR >= 45 ~ 3, 
      eGFR >= 30 ~ 3.5, 
      eGFR >= 15 ~ 4, 
      eGFR <  15 ~ 5, 
      .default = 0),  # if missing, assume eGFR is good
    eGFR_UNDER_60 = ifelse(eGFR < 60, 1, 0) %>% 
      coalesce(0), # if missing, assume eGFR is good (above 60)
    ALBUM_UNDER_3 = ifelse(TOT_SERUM_ALBUM < 3, 1, 0) %>% 
      coalesce(0) # if missing, assume Albumin is good (above 3)
  )
  #: outliers in median refusals
  # step_mutate(median_refusals = pmin(median_refusals, 20)) 

Logistic Regression

1. Tidymodels specification

Lasso logistic regression model.

  1. Remove LISTING_CTR_CODE
  2. Add new missing indicator feature for all variables with missing
  3. Convert Functional Status to number {0, 10, …, 100}. Add indicator for Unknown status.
  4. One-hot encode all categorical predictors. For variables with {Yes, No, Unknown}, only keep the Yes column. This lumps the Unknown with No. 
  5. Truncate median_refusals to 20.
  6. Impute all missing values with median.
  7. Create new features by coding eGFR into stages of chronic kidney failure.
  8. Create binary TOT_SERUM_ALBUM < 3 indicator.
  9. Add polynomial terms for the numeric features.
Code
library(tidymodels)

# Model specification: Lasso penalized logistic regression
lasso_spec = 
  logistic_reg() %>%
  set_engine("glmnet") %>%
  set_args(
    mixture = 1,    # 1 = lasso, 0 = ridge
    penalty = tune()
  ) 
  
# Recipe: extends base_rec with lasso-specific preprocessing
# (missing indicators, one-hot dummies, median imputation, polynomials).
lasso_rec = 
  base_rec %>% 
  #: Remove additional variables for this model
  step_rm(
    LISTING_CTR_CODE,
  ) %>%
  #: Add additional variables to represent missing predictors
  step_indicate_na(all_predictors()) %>% 
  #: Convert categorical predictors to dummy 
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>% # one-hot
  step_dummy(all_ordered_predictors(), one_hot = TRUE) %>% # one-hot
  step_rm(ends_with("_No")) %>%       # removes the No level (Yes is binary)
  step_rm(ends_with("_Unknown")) %>%  # removes the Unknown level (Yes is binary).
                                # This effectively treats "Unknown" same as "No"
  #: Impute missing values
  step_impute_median(all_numeric_predictors()) %>% 
  #: Add polynomial terms (orthogonal polynomials, degree 2)
  step_poly(
    AGE, WEIGHT_KG, HEIGHT_CM, BMI, BSA,
    eGFR,
    # TOT_SERUM_ALBUM,
    FUNC_STAT_NUM,
    # pedhrtx_prev_yr, 
    # median_refusals,
    # LC_effect
    degree = 2,
  ) %>% 
  #: Remove all zero variance predictors (e.g., from step_indicate_na() )
  step_zv(all_predictors()) %>%      
  #: Scale all predictors for variable importance scoring
  step_scale(all_numeric_predictors()) 

# Workflow: recipe + spec, tuned/fitted later
lasso_wflow = 
  workflow(preprocessor = lasso_rec, spec = lasso_spec)

2. Preprocess the training data

Code
# Pre-process training data
# prep() estimates the recipe on data_train; bake(new_data = NULL) returns
# the already-processed training design matrix / outcome vector for glmnet.
lasso_rec_fitted = prep(lasso_rec, data_train)
X_train_lasso = bake(lasso_rec_fitted, new_data = NULL, 
               all_predictors(), composition = "matrix")
Y_train_lasso = bake(lasso_rec_fitted, new_data = NULL, 
               all_outcomes()) %>% pull()

3. Tune lambda using cross-validation

Tuning the \(\lambda\) (or penalty) parameter using cross-validation to maximize the AUC.

Code
# Create 10 fold cv indices
# set.seed(100)
# folds = sample(rep(1:10, length = nrow(X_train_lasso)))

# Reuse the rsample folds for cv.glmnet via foldid so every model sees the
# same resampling. Assumes each training row appears in exactly one
# assessment set (true for vfold_cv), and that broom::tidy() Rows align
# with the rows of X_train_lasso.
folds = cv_folds %>% broom::tidy() %>% filter(Data == "Assessment") %>%
  arrange(Row) %>% mutate(Fold = readr::parse_number(Fold)) %>% pull(Fold)

# Run 10-fold CV on training data to estimate lambda
library(glmnet)
cv_fit = cv.glmnet(
  X_train_lasso, Y_train_lasso,
  family = "binomial",
  alpha = 1,    # lasso
  relax = FALSE,
  foldid = folds,
  type.measure = "auc")

4. CV performance

Code
# plot(cv_fit)
# Print CV summary: AUC at lambda.min and lambda.1se
cv_fit

Call:  cv.glmnet(x = X_train_lasso, y = Y_train_lasso, type.measure = "auc",      foldid = folds, relax = FALSE, family = "binomial", alpha = 1) 

Measure: AUC 

      Lambda Index Measure      SE Nonzero
min 0.002341    34  0.7443 0.01364      31
1se 0.018122    12  0.7321 0.01366       9

5. Test Performance

Code
# Bake the test data with the *training*-fitted recipe (no re-estimation).
X_test_lasso = bake(lasso_rec_fitted, new_data = data_test, 
              all_predictors(), composition = "matrix")

# Test-set metrics at lambda.min, lambda.1se, and unpenalized (s = 0).
# NOTE(review): `1 - predict(...)` flips the predicted probability --
# presumably converting P(event) to P(survival) given the c(1,0) factor
# coding; confirm against glmnet's level convention.
data_test %>% 
  transmute(
    WL_ID_CODE, 
    outcome = factor(outcome, c(1,0)), 
    lasso.min =  1-predict(cv_fit, X_test_lasso, type = "response", 
                           s = "lambda.min")[,1], 
    lasso.1se = 1-predict(cv_fit, X_test_lasso, type = "response", 
                          s = "lambda.1se")[,1],
    lasso.unpenalized = 1-predict(cv_fit, X_test_lasso, type = "response", 
                          s = 0)[,1],
  ) %>% 
  pivot_longer(starts_with("lasso")) %>% group_by(name) %>% 
  reframe(calc_metrics(outcome, value)) %>% 
  arrange(-auc) %>% 
  mutate_all(\(x) digits(x, 3))

6. Variable Importance

Code
# Nonzero lasso coefficients at lambda.min; predictors were scaled in the
# recipe, so magnitudes are comparable as importance scores.
vip::vi(cv_fit, lambda = cv_fit$lambda.min) %>% 
  filter(Importance > 1E-8) 
Code
# Plot at (1se or min) lambda 
# Lollipop plot of signed coefficient magnitudes; the str_replace chain
# prettifies engineered feature names (order matters: poly suffixes first).
vip::vi(cv_fit, lambda = cv_fit$lambda.min) %>%   # cv_fit$lambda.1se
  filter(Importance > 1E-8) %>% 
  mutate(
    Variable = Variable %>% 
      str_replace("_poly_1", " (linear)") %>% 
      str_replace("_poly_2", " (quadratic)") %>% 
      str_replace("_poly_3", " (cubic)") %>% 
      # str_replace_all("_", " ") %>% 
      str_replace("(CAND_DIAG)_(.+)", "\\1: \\2") %>% 
      str_replace_all("\\.", " ") %>% 
      str_wrap(30, whitespace_only = FALSE),
    Variable = fct_reorder(Variable, abs(Importance)),
    Importance = ifelse(Sign == "NEG", -Importance, Importance),
    ) %>% 
  ggplot(aes(Importance, Variable, color = Sign)) + 
  geom_point() + 
  geom_segment(aes(xend = 0, yend = Variable)) + 
  scale_color_brewer(type = "qual", palette = 2) + 
  labs(y=  "", title = "Predicting Waitlist Survival")

7. Additive Effects Plot

Code
#: Get final fit with tuned lambda/penalty
# Refit the full workflow on all training data at the CV-selected lambda.
lasso_fit = lasso_wflow %>% 
  finalize_workflow(tibble(penalty = cv_fit$lambda.min)) %>% 
  fit(data_train)
Code
# Map engineered/dummy feature names back to the raw variable names in the
# source data (used to aggregate importance scores by original variable).
#
# @param x character vector of feature names produced by the recipe
#   (e.g. "na_ind_BMI", "AGE_poly_2", "RACE_White").
# @return character vector of raw variable names, same length as `x`.
#   Order of the replacements matters (suffixes are stripped before the
#   prefix collapses run).
get_raw_variable_names <- function(x){
  x %>% 
      str_remove("na_ind_") %>% 
      str_remove("_poly_\\d") %>% 
      str_remove("_Unknown") %>% 
      str_remove("_Yes") %>% 
      str_replace("FUNC_STAT_NUM", "FUNC_STAT_CAND_REG") %>% 
      str_replace("ABO(_.+)", "ABO") %>% 
      str_replace("CAND_DIAG_CODE(_.+)", "CAND_DIAG_CODE") %>% 
      # protect CAND_DIAG_CODE from the broader CAND_DIAG_* collapse
      {ifelse(. == "CAND_DIAG_CODE", ., str_replace(.,"CAND_DIAG(_.+)", "CAND_DIAG"))} %>%
      str_replace("LIFE_SUPPORT_CAND_REG(_.+)", "LIFE_SUPPORT_CAND_REG") %>% 
      str_replace("eGFR(_.+)", "eGFR") %>% # since eGFR_CODED isn't in data
      str_replace("RACE(_.+)", "RACE") %>% 
      str_replace("GENDER(_.+)", "GENDER") %>%
      str_replace("STATUS(_.+)", "STATUS") %>%
      str_replace("LISTING_CTR_CODE(_X.+)", "LISTING_CTR_CODE") %>% 
      str_replace("ALBUM_UNDER_3",  "TOT_SERUM_ALBUM") %>% 
      str_replace("CAND_UNDER_1",  "AGE") %>% 
      str_replace("REGION(_X.+)", "REGION")
      # (duplicate LISTING_CTR_CODE replacement removed: the pattern cannot
      #  match again after the first application, so it was a no-op)
}
Code
# Group the nonzero lasso features by their raw source variable and rank
# variables by the mean importance of their derived features.
imp_features = vip::vi(cv_fit, lambda = cv_fit$lambda.min) %>% 
  filter(Importance > 1E-8) %>% 
  mutate(var_raw = get_raw_variable_names(Variable)) %>% 
  mutate(.by = var_raw, total_importance = mean(Importance)) %>% 
  arrange(-total_importance, -Importance)

imp_vars = imp_features %>% distinct(var_raw) %>% pull()

# One additive-effects plot per important raw variable (side effects only).
walk(imp_vars, plot_additive_effects, model = lasso_fit)

GAM

1. Tidymodels recipe

Code
# ?details_gen_additive_mod_mgcv
library(mgcv)
# Recipe: like the lasso recipe but categorical predictors stay as factors
# (mgcv handles them via random-effect smooths), so no dummy steps.
gam_rec = 
  base_rec %>% 
  #: Remove additional variables for this model
  step_rm(
    LISTING_CTR_CODE,
  ) %>% 
  #: Add additional variables to represent missing predictors
  step_indicate_na(all_predictors()) %>% 
  #: Impute missing values
  step_impute_median(all_numeric_predictors()) %>% 
  #: Remove all zero variance predictors (e.g., from step_indicate_na() )
  step_zv(all_predictors()) %>%      
  #: Scale all predictors for variable importance scoring
  step_scale(all_numeric_predictors()) 

#: train recipe and create X matrix
rec_fitted = prep(gam_rec, data_train)
data_train_gam = bake(rec_fitted, new_data = NULL)

2. Fit GAM

Code
# GAM with thin-plate smooths s(...) for continuous terms, random-effect
# smooths (bs = "re") for categorical terms, and plain linear terms for
# binary/effect covariates. Commented-out terms record alternatives that
# were tried and dropped.
fit_gam = mgcv::gam(outcome ~ 
                s(HEIGHT_CM) + 
                s(AGE) + 
                # s(AGE,HEIGHT_CM) + 
                # s(AGE, WEIGHT_KG) + 
                # s(BMI) + s(WEIGHT_KG) + 
                # s(GENDER, bs = "re") + 
                # s(HEIGHT_CM_PERC) +
                # s(WEIGHT_KG_PERC) + 
                s(RACE, bs = "re") + 
                # s(CITIZENSHIP, bs = "re") + 
              # s(STATUS, bs = "re") + 
                s(ABO, bs = "re") + 
                s(LIFE_SUPPORT_CAND_REG, bs = "re") +
                # s(LIFE_SUPPORT_OTHER, bs = "re") +
                # s(PGE_TCR, bs = "re") +
                ECMO_CAND_REG + #s(ECMO_CAND_REG, bs = "re") +
                s(VAD_CAND_REG, bs = "re") +
                VENTILATOR_CAND_REG +  #s(VENTILATOR_CAND_REG, bs = "re") +
               # s(FUNC_STAT_CAND_REG, bs = "re") +
                s(FUNC_STAT_NUM) + 
                s(FUNC_STAT_UNKNOWN, bs = "re") + 
                s(CAND_UNDER_1, bs = "re") + 
                # s(WL_OTHER_ORG, bs = "re") +
                # s(CEREB_VASC, bs = "re") +
                # s(DIAB, bs = "re") +
                s(DIALYSIS_CAND, bs = "re") +
                # s(HEMODYNAMICS_CO,  bs = "re") +
                # s(IMPL_DEFIBRIL, bs = "re") +
                s(INOTROP_VASO_CO_REG, bs = "re") +
                # s(INOTROPES_TCR, bs = "re") +
                # s(MOST_RCNT_CREAT) +
                eGFR_CODED + 
                # I(eGFR < 60) + 
                # s(eGFR) +
                I(TOT_SERUM_ALBUM < 3) +
                # s(TOT_SERUM_ALBUM) +
                s(CAND_DIAG, bs = "re") +
                # s(WL_OTHER_ORG, bs = "re") + 
              # s(LISTING_CTR_CODE, bs = "re") +
                # s(LIST_YR) +
                # s(REGION, bs = "re") +
                # s(LC_effect, k=4) + 
                # s(median_wait_days_1A, k = 3) + 
                # s(median_refusals, k = 3) + 
                # s(pedhrtx_prev_yr, k = 3) + 
                LC_effect +  
                median_refusals +  
                pedhrtx_prev_yr,
              # select=TRUE,
              method = "GCV.Cp", 
              data = data_train_gam, 
              family = binomial())

3. Variable Importance

Code
# Parametric coefficients and smooth-term significance
summary(fit_gam)

Family: binomial 
Link function: logit 

Formula:
outcome ~ s(HEIGHT_CM) + s(AGE) + s(RACE, bs = "re") + s(ABO, 
    bs = "re") + s(LIFE_SUPPORT_CAND_REG, bs = "re") + ECMO_CAND_REG + 
    s(VAD_CAND_REG, bs = "re") + VENTILATOR_CAND_REG + s(FUNC_STAT_NUM) + 
    s(FUNC_STAT_UNKNOWN, bs = "re") + s(CAND_UNDER_1, bs = "re") + 
    s(DIALYSIS_CAND, bs = "re") + s(INOTROP_VASO_CO_REG, bs = "re") + 
    eGFR_CODED + I(TOT_SERUM_ALBUM < 3) + s(CAND_DIAG, bs = "re") + 
    LC_effect + median_refusals + pedhrtx_prev_yr

Parametric coefficients:
                           Estimate Std. Error z value Pr(>|z|)    
(Intercept)                 3.66681    4.34068   0.845   0.3982    
ECMO_CAND_REG              -0.18039    0.04192  -4.303 1.68e-05 ***
VENTILATOR_CAND_REG        -0.23476    0.05014  -4.682 2.84e-06 ***
eGFR_CODED                 -0.08173    0.05719  -1.429   0.1530    
I(TOT_SERUM_ALBUM < 3)TRUE -0.02178    0.22014  -0.099   0.9212    
LC_effect                   0.07760    0.05655   1.372   0.1700    
median_refusals            -0.28123    0.04833  -5.819 5.93e-09 ***
pedhrtx_prev_yr             0.12694    0.05804   2.187   0.0287 *  
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Approximate significance of smooth terms:
                               edf Ref.df Chi.sq  p-value    
s(HEIGHT_CM)             3.264e+00  4.149 28.276 1.57e-05 ***
s(AGE)                   1.003e+00  1.005 10.885  0.00101 ** 
s(RACE)                  1.731e+00  4.000  7.958  0.00545 ** 
s(ABO)                   1.388e+00  3.000  2.522  0.15613    
s(LIFE_SUPPORT_CAND_REG) 2.480e-05  2.000  0.000  0.56455    
s(VAD_CAND_REG)          3.183e-05  1.000  0.000  0.79378    
s(FUNC_STAT_NUM)         1.000e+00  1.000  0.151  0.69733    
s(FUNC_STAT_UNKNOWN)     3.331e-05  1.000  0.000  0.53727    
s(CAND_UNDER_1)          7.148e-01  1.000  3.851  0.01657 *  
s(DIALYSIS_CAND)         1.781e+00  2.000 13.156  0.00083 ***
s(INOTROP_VASO_CO_REG)   1.790e+00  2.000  9.071  0.00832 ** 
s(CAND_DIAG)             3.711e+00  7.000 57.028  < 2e-16 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

R-sq.(adj) =  0.0946   Deviance explained = 13.1%
UBRE = -0.43619  Scale est. = 1         n = 4523
Code
# Combined table of smooth and parametric terms, sorted by p-value.
# coalesce(edf, 1): parametric terms have no edf, so report 1.
bind_rows(tidy(fit_gam), tidy(fit_gam, parametric=TRUE)) %>% 
  arrange(p.value) %>% 
  transmute(
    var = str_replace(term, "s\\((.+)\\)", "\\1"), 
    # term, 
    edf = coalesce(edf, 1) %>% digits(3), 
    p.value = digits(p.value,4)
  ) 

4. Test Performance

Code
# Bake the test data with the training-fitted GAM recipe, then score.
data_test_gam = bake(rec_fitted, new_data = data_test)

data_test_gam %>%
  transmute(
    outcome = factor(outcome, c(1,0)), 
    # 1 - p: flip predicted probability to the survival scale
    p_gam =  1-predict(fit_gam, ., type = "response") %>% as.numeric
  ) %>% 
  reframe(calc_metrics(outcome, p_gam)) %>% 
  mutate_all(\(x) digits(x, 3))

5. Partial Dependence Plots

Code
library(gratia)
# One nested row per model term: smooth estimates plus parametric
# ("unpenalized") effects, joined with each term's edf and p-value.
gam_plot_data = bind_rows(
    # Smooth effects
    gratia::smooth_estimates(fit_gam, unnest = FALSE) %>% 
    # term's covariate = last column of the nested data (by gratia layout)
    mutate(var = map_chr(data, \(x) tail(colnames(x),1)), .before = 1) %>% 
    select(-.smooth), 
    # unpenalized
    gratia::parametric_effects(fit_gam, unnest = FALSE) %>% 
    mutate(.type = "unpenalized") %>% 
    rename(var = .term)
  ) %>% 
  # Add edf and p.value
  left_join(
    bind_rows(tidy(fit_gam), tidy(fit_gam, parametric=TRUE)) %>% 
      arrange(p.value) %>% 
      transmute(
        var = str_replace(term, "s\\((.+)\\)", "\\1") %>% 
          str_remove("TRUE"), 
        edf = coalesce(edf, 1) %>% digits(3), 
        p.value = digits(p.value,4)
      ), 
    by = "var"
  ) %>% 
  arrange(p.value)
Code
# Plot the partial effect for one row (term) of gam_plot_data.
#
# @param select integer row index into gam_plot_data; `!!select` injects
#   the value so slice() does not mistake it for a column name.
# Side effect: prints the plot (categorical vs numeric chosen by type and
# by number of distinct values).
plot_gam_effects <- function(select = 1){
  df = gam_plot_data %>% 
    slice(!!select) %>% 
    unnest(data)
  var_name = df$var[1]
  if(df$.type[1] == "unpenalized"){
    # parametric terms store their estimate in .partial / .value
    df = df %>% mutate( !!var_name := .value, .estimate = .partial)
  }
  categorical = is.character(df[[var_name]]) | is.factor(df[[var_name]]) | nrow(df) < 6
  if(categorical) {
    plt = plot_categorical_effects(df[[var_name]], df[[".estimate"]], xlab = var_name)
  } else{
    plt = plot_numeric_effects(df[[var_name]], df[[".estimate"]], xlab = var_name)
  }
  
  print(
    plt +
    labs(y = "partial effect") +
    scale_y_continuous(breaks = seq(-10, 10, by = .25)) +
    coord_cartesian(ylim = c(-1,1))
  )  
  
}
# Draw every term's partial-effect plot
walk(1:nrow(gam_plot_data), \(i) plot_gam_effects(i))

Random Forest

1. Tidymodels specification

Code
# Model specification: Random Forest
# trees/min_n fixed; only mtry is tuned (via OOB, below).
rf_spec = 
  rand_forest() %>%
  set_mode("classification") %>%
  set_engine("ranger", 
    seed = 2024, 
    importance = "impurity", #"none", 
    num.threads = 8
  ) %>% 
  set_args(
    mtry = tune(), 
    trees = 2000, 
    min_n = 2
  ) 

# Recipe: same dummy/indicator handling as the lasso, minus scaling and
# polynomial terms (trees don't need them).
rf_rec = 
  base_rec %>% 
  #: Remove additional variables for this model
  step_rm(
    LISTING_CTR_CODE,
  ) %>% 
  #: Add additional variables to represent missing predictors
  step_indicate_na(all_predictors()) %>% 
  #: Convert categorical predictors to dummy 
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>% # one-hot
  # step_dummy(all_ordered_predictors(), one_hot = TRUE) %>% # one-hot
  step_rm(ends_with("_No")) %>%       # removes the No level (Yes is binary)
  step_rm(ends_with("_Unknown")) %>%  # removes the Unknown level (Yes is binary).
                                # This effectively treats "Unknown" same as "No"
  #: Impute missing values
  step_impute_median(all_numeric_predictors()) %>% 
  #: Remove all zero variance predictors (e.g., from step_indicate_na() )
  step_zv(all_predictors())

# Workflow:
rf_wflow = 
  workflow(preprocessor = rf_rec, spec = rf_spec)

2. Tune mtry

Use OOB observations for tuning instead of cross-validation.

Code
# pre-tuning: use oob brier score to seed the full cross-val grid search
#             this is used to speed along tune_grid() 

# Fit the RF workflow at a single mtry and return its OOB error.
# (per help(ranger), the Brier score is the OOB error for probability trees)
oob_brier <- function(mtry){
  # according to help(ranger), brier metric is used for oob error
  fit = rf_wflow %>% 
    finalize_workflow(list(mtry = mtry)) %>% 
    fit(data_train) 
  tibble(
    mtry, 
    oob_error = fit %>% extract_fit_engine() %>% pluck("prediction.error")
  )
}
num_cols = prep(rf_rec, data_train)$term_info %>% nrow()    # number of features
mtry_max = min(num_cols, 2*sqrt(num_cols)) %>% floor() # max mtry to try
# candidate mtry values: ~50 evenly spaced integers in [1, mtry_max]
# (spelled out as length.out rather than relying on partial arg matching)
mtry_oob_grid = seq(1, mtry_max, length.out = 50) %>% floor() %>% unique()
oob_perf = map_df(mtry_oob_grid, oob_brier) # get oob performance
mtry_grid = oob_perf %>% 
  slice_min(oob_error, n = 5) %>% # keep best 5 mtry values
  select(mtry)
Code
# set.seed(1000)
# tune_res = tune_grid(
#   object = rf_wflow,
#   resamples = cv_folds,
#   grid = mtry_grid,
#   metrics = metric_set(roc_auc, brier_class, mn_log_loss, accuracy), 
#   control = control_grid(verbose=FALSE)
# )

3. CV performance

Code
#: select from oob
# Pick the smallest mtry among the top-5 OOB performers.
# (was `slice_min(mtry_grid, ...)` -- that ordered by the whole tibble
#  found via its environment binding; order by the `mtry` column explicitly)
rf_tune = mtry_grid %>% slice_min(mtry, n=1, with_ties = FALSE)

set.seed(1000)
# Cross-validated metrics for the OOB-selected mtry
fit_resamples(
  object = rf_wflow %>% finalize_workflow(rf_tune),
  resamples = cv_folds,
  metrics = metric_set(roc_auc, brier_class, mn_log_loss, accuracy), 
  control = control_grid(verbose=FALSE)  
) %>% collect_metrics()

4. Test performance

Fit random forest

Code
# Final random-forest fit on all training data at the selected mtry.
rf_fit = rf_wflow %>% 
  finalize_workflow(rf_tune) %>% 
  fit(data_train)

Test performance

Code
# Test-set metrics; .pred_1 is the probability of outcome level "1".
data_test %>% 
  transmute(
    WL_ID_CODE, 
    outcome = factor(outcome, c(1,0)), 
    p_rf =  predict(rf_fit, data_test, type = "prob")$.pred_1
  ) %>% 
  reframe(calc_metrics(outcome, p_rf)) %>% mutate_all(\(x) digits(x, 3))

5. Variable Importance

Code
rf_fit
══ Workflow [trained] ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()

── Preprocessor ──────────────────────────────────────────────────────────────────────────
14 Recipe Steps

• step_mutate()
• step_rm()
• step_rm()
• step_mutate()
• step_mutate()
• step_rm()
• step_mutate()
• step_rm()
• step_indicate_na()
• step_dummy()
• ...
• and 4 more steps.

── Model ─────────────────────────────────────────────────────────────────────────────────
Ranger result

Call:
 ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~2,      x), num.trees = ~2000, min.node.size = min_rows(~2, x), seed = ~2024,      importance = ~"impurity", num.threads = ~8, verbose = FALSE,      probability = TRUE) 

Type:                             Probability estimation 
Number of trees:                  2000 
Sample size:                      4523 
Number of independent variables:  53 
Mtry:                             2 
Target node size:                 2 
Variable importance mode:         impurity 
Splitrule:                        gini 
OOB prediction error (Brier s.):  0.08212073 
Code
# Impurity-based importance of RF features (nonzero only)
rf_fit %>% extract_fit_parsnip() %>% vip::vi() %>% filter(Importance > 1E-8)

XGBoost

1. Tidymodels specification

Fixing learn_rate (eta) at 0.10 and sample_size at 0.80, and tuning tree_depth and trees. Fixing learn_rate while tuning the number of trees is efficient because xgboost supports multi-predict (evaluating many tree counts from a single fit).

  1. Remove LISTING_CTR_CODE
  2. One-hot encode all nominal predictors. For variables with {Yes, No, Unknown}, only keep the Yes column. This lumps the Unknown with No. 
  3. Let xgboost handle missing values
Code
library(xgboost)
library(tidymodels)

# Model specification: XGBoost
# learn_rate and sample_size fixed; trees and tree_depth tuned.
bt_spec = 
  boost_tree() %>% 
  set_mode("classification") %>% 
  set_engine("xgboost",  nthread = 8) %>% 
  set_args(
    trees = tune(), 
    tree_depth = tune(),
    learn_rate = 0.10, # fixed 
    sample_size = .80, # fixed
  ) 

# Recipe:
## Let xgboost handle missing values internally; does not impute
bt_rec = 
  base_rec %>% 
  step_rm(
    LISTING_CTR_CODE,
  ) %>% 
  # step_dummy(LISTING_CTR_CODE, one_hot = TRUE) %>% 
  #: Convert categorical predictors to dummy 
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>% # one-hot
  # step_dummy(all_ordered_predictors(), one_hot = TRUE) %>% # one-hot
  step_rm(ends_with("_No")) %>%       # removes the No level (Yes is binary)
  step_rm(ends_with("_Unknown")) %>%  # removes the Unknown level (Yes is binary).
                                # This effectively treats "Unknown" same as "No"
  #: Remove all zero variance predictors (e.g., from step_indicate_na() )
  step_zv(all_predictors())    

# Workflow:
bt_wflow = 
  workflow(preprocessor = bt_rec, spec = bt_spec)

2. Tuning

Code
# Tuning grid. Use fewer trees from larger depth
# (deeper trees need fewer boosting rounds before overfitting)
bt_grid = 
  bind_rows(
    tibble(trees = seq(25, 300, by = 5), tree_depth = 1), 
    tibble(trees = seq(10, 150, by = 5), tree_depth = 2),
    tibble(trees = seq(10, 75, by = 5), tree_depth = 3), 
    tibble(trees = seq(10, 50, by = 5), tree_depth = 4), 
  )
# expand_grid(trees = seq(25, 250, by = 10), tree_depth = 1:4)

#: don't use bayes here since it won't exploit the multi-predict efficiency
set.seed(1000)
tune_res = tune_grid(
  object = bt_wflow,
  resamples = cv_folds,
  grid = bt_grid,
  metrics = metric_set(roc_auc, brier_class, mn_log_loss, accuracy), 
  control = control_grid(verbose=FALSE)
)

3. CV performance

Code
# Top tuning configurations by cross-validated AUC
tune_res %>% show_best(metric = "roc_auc")
Code
# AUC vs number of trees, one curve per tree depth
tune_res %>% collect_metrics() %>% filter(.metric == "roc_auc") %>% 
  arrange(trees) %>% 
  ggplot(aes(trees, mean, color = factor(tree_depth))) + 
  geom_point() + 
  geom_line() + 
  labs(x = "Number of Trees", y = "Avg AUC", color = "tree depth")

4. Final model fit

Code
# Final model fit
# (outer parens print the selected parameters while assigning)
(bt_tune = tune_res %>% select_best(metric = "roc_auc"))
Code
# Refit on all training data at the CV-selected trees/tree_depth
set.seed(1234)
bt_fit = finalize_workflow(bt_wflow, bt_tune) %>% fit(data_train)

5. Variable Importance

Code
bt_fit
══ Workflow [trained] ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: boost_tree()

── Preprocessor ──────────────────────────────────────────────────────────────────────────
12 Recipe Steps

• step_mutate()
• step_rm()
• step_rm()
• step_mutate()
• step_mutate()
• step_rm()
• step_mutate()
• step_rm()
• step_dummy()
• step_rm()
• ...
• and 2 more steps.

── Model ─────────────────────────────────────────────────────────────────────────────────
##### xgb.Booster
raw: 107.6 Kb 
call:
  xgboost::xgb.train(params = list(eta = 0.1, max_depth = 1, gamma = 0, 
    colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
    subsample = 0.8), data = x$data, nrounds = 140, watchlist = x$watchlist, 
    verbose = 0, nthread = 8, objective = "binary:logistic")
params (as set within xgb.train):
  eta = "0.1", max_depth = "1", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "0.8", nthread = "8", objective = "binary:logistic", validate_parameters = "TRUE"
xgb.attributes:
  niter
callbacks:
  cb.evaluation.log()
# of features: 49 
niter: 140
nfeatures : 49 
evaluation_log:
    iter training_logloss
       1        0.6309374
       2        0.5796070
---                      
     139        0.2750236
     140        0.2749684
Code
# Gain-based importance of XGBoost features (nonzero only)
bt_fit %>% extract_fit_parsnip() %>% vip::vi() %>% filter(Importance > 1E-8)
Code
# Test-set metrics for the boosted-tree model
data_test %>% 
  transmute(
    WL_ID_CODE, 
    outcome = factor(outcome, c(1,0)), 
    p_bt =  predict(bt_fit, data_test, type = "prob")$.pred_1
  ) %>% 
  reframe(calc_metrics(outcome, p_bt)) %>% mutate_all(\(x) digits(x, 3))

6. SHAP Dependence Plot

Get SHAP values

Code
# Baked training matrix for raw xgboost prediction
data_train_xgb = bt_rec %>% 
  prep(data_train) %>% 
  bake(data_train, all_predictors(), composition = "matrix")

# predcontrib = TRUE returns per-feature SHAP contributions (plus a BIAS
# column) instead of predictions
bt_shap = predict(
  bt_fit %>% extract_fit_engine(), 
  data_train_xgb, 
  predcontrib = TRUE) %>% 
  as_tibble()

SHAP Importance (Mean absolute SHAP value)

Code
# Importance = mean absolute SHAP value per feature (BIAS column dropped)
bt_shap_imp_features = bt_shap %>% select(-BIAS) %>% 
  map_dbl(\(x) mean(abs(x))) %>% enframe(name = "feature", value="shap_imp") %>% 
  filter(shap_imp > 0) %>% arrange(-shap_imp)

bt_shap_imp_features

SHAP dependence plots (NOTE: not showing categorical)

Code
# SHAP dependence plot for a single feature.
#
# @param var  feature name (string or bare symbol; captured with ensym()).
# @param shap tibble of SHAP contributions, one column per feature.
# @param data feature matrix the SHAP values were computed from.
# Side effect: prints the plot (numeric vs categorical chosen by the number
# of distinct values and the column type).
plot_shap_effects <- function(var, shap=bt_shap, data = data_train_xgb){
  var = rlang::ensym(var)
  # use the `data` argument (the original read the global data_train_xgb
  # directly, so passing a different matrix had no effect)
  df = data %>% as_tibble() %>% 
    select({{var}}) %>% 
    mutate(
      # sign flipped so effects are on the survival (not mortality) scale
      shap = -pull(shap, {{var}})
    ) 
  rug_data = df %>% count({{var}}) %>% rename(x = {{var}})
  df = distinct(df)
  n_bks = nrow(rug_data)
  categorical = is.character(df[[1]]) | is.factor(df[[1]])
  if(n_bks > 6 & !categorical ){
    plt = plot_numeric_effects(df[[1]], df[[2]], var, rug_data)
  } else{
    plt = plot_categorical_effects(df[[1]], df[[2]], var, rug_data)
  }
  
  print(
    plt +
    labs(y = "partial effect") +
    scale_y_continuous(breaks = seq(-10, 10, by = .25)) +
    coord_cartesian(ylim = c(-1,1))
  )
}


# plot_shap_effects("CAND_DIAG_Congenital.Heart.Disease.With.Surgery")

# One SHAP dependence plot per feature with nonzero importance
walk(bt_shap_imp_features$feature, plot_shap_effects)

Details on eGFR impact

Code
# eGFR deep dive
# Vertical lines mark the clinical CKD-stage cutoffs used in eGFR_CODED
plot_shap_effects(eGFR) + scale_x_continuous(breaks = seq(0, 1000, by = 15)) +
  geom_vline(xintercept = c(15, 30, 45, 60, 90), color = "purple", alpha = .25)

Code
# Table of eGFR values where the SHAP effect changes (tree split points)
data_train_xgb %>% as_tibble() %>% 
  select(eGFR) %>% 
  mutate(
    shap = -pull(bt_shap, eGFR)    
  ) %>% 
  distinct() %>% arrange(eGFR) %>% 
  filter(lag(shap, default=0) != shap)

Details on Albumin impact

Code
# ALBUM deep dive
# Build the plot once, re-render with a custom axis in the next chunk
plt = plot_shap_effects(TOT_SERUM_ALBUM) 

Code
# Same plot with unit x-axis breaks
plt + scale_x_continuous(breaks = seq(0, 10, by=1))

Code
# Albumin values where the SHAP effect changes (tree split points)
data_train_xgb %>% as_tibble() %>% 
  select(TOT_SERUM_ALBUM) %>% 
  mutate(
    shap = -pull(bt_shap, TOT_SERUM_ALBUM)    
  ) %>% 
  distinct() %>% arrange(TOT_SERUM_ALBUM) %>% 
  filter(lag(shap, default=0) != shap)

7. Partial Dependence Plots

Code
# Important boosted-tree features mapped back to raw variable names.
# NOTE(review): this chain nearly duplicates get_raw_variable_names() but
# differs (e.g. "eGFR_CODED" vs "eGFR(_.+)") -- presumably intentional for
# the xgboost feature set; consider consolidating.
imp_features = bt_fit %>% 
  extract_fit_parsnip() %>% vip::vi() %>% filter(Importance > 1E-8)

vars = imp_features %>% 
  pull(Variable) %>% 
  str_remove("na_ind_") %>% #str_remove("_X.+") %>% 
  str_remove("_Unknown") %>% str_remove("_Yes") %>% 
  str_replace("FUNC_STAT_NUM", "FUNC_STAT_CAND_REG") %>% 
  str_replace("ABO(_.+)", "ABO") %>% 
  str_replace("CAND_DIAG_CODE(_.+)", "CAND_DIAG_CODE") %>% 
  {ifelse(. == "CAND_DIAG_CODE", ., str_replace(.,"CAND_DIAG(_.+)", "CAND_DIAG"))} %>%
  str_replace("LIFE_SUPPORT_CAND_REG(_.+)", "LIFE_SUPPORT_CAND_REG") %>% 
str_replace("eGFR_CODED", "eGFR") %>% # since eGFR_CODED isn't in data
  str_replace("RACE(_.+)", "RACE") %>% 
  str_replace("LISTING_CTR_CODE(_X.+)", "LISTING_CTR_CODE") %>% 
  unique() 


# One additive-effects plot per raw variable
walk(vars, plot_additive_effects, model = bt_fit)

Listing Center Only

Code
# Per-center survival rates with Laplace-style shrinkage toward the overall
# training base rate (small centers are pulled toward base_survival).
base_survival = 1-mean(data_train$outcome)
k = 5 # shrinkage/laplace parameter
LC_train = data_train %>% 
  group_by(LISTING_CTR_CODE) %>% 
  summarize(
    n = n(),
    p = 1 - mean(outcome),
    p_survival = (p * n  + base_survival * k) / (n+k)
  )

# Predict waitlist survival probability from the listing center alone.
#
# @param LISTING_CTR_CODE vector of center codes.
# @return shrunken survival rate per center; codes not present in LC_train
#         fall back to the overall base_survival rate.
predict_LC <- function(LISTING_CTR_CODE){
  # Keyed lookup via match(). The original indexed p_survival directly with
  # the raw codes (`p_survival[LISTING_CTR_CODE]`), which is positional /
  # name indexing, not a lookup, and silently misaligns (or returns all NA)
  # for character or factor codes whose values differ from row positions.
  idx = match(LISTING_CTR_CODE, LC_train$LISTING_CTR_CODE)
  LC_train$p_survival[idx] %>% 
    coalesce(base_survival)
}

Status Only

Code
# Per-status survival rates with the same shrinkage scheme as LC_train.
base_survival = 1-mean(data_train$outcome)
k = 5 # shrinkage/laplace parameter
STATUS_train = data_train %>% 
  group_by(STATUS) %>% 
  summarize(
    n = n(),
    p = 1 - mean(outcome),
    p_survival = (p * n  + base_survival * k) / (n+k)
  )

predict_STATUS <- function(STATUS){
  STATUS_train$p_survival[STATUS] %>% 
    coalesce(base_survival)
}

Model Comparison

Test Performance

Code
#: test-set performance for every model; each model column holds the
#: predicted probability of outcome == 1
data_test %>% 
  transmute(
    WL_ID_CODE, 
    outcome = factor(outcome, c(1,0)), 
    lasso =  1-predict(cv_fit, X_test_lasso, type = "response", 
                           s = "lambda.min")[,1], 
    xgboost =  predict(bt_fit, data_test, type = "prob")$.pred_1, 
    RF =  predict(rf_fit, data_test, type = "prob")$.pred_1,
    GAM = 1-predict(fit_gam, data_test_gam, type = "response") %>% as.numeric,
    ensemble = (lasso + xgboost + RF + GAM) / 4,
    LC = 1-predict_LC(data_test$LISTING_CTR_CODE),
    STATUS = 1-predict_STATUS(data_test$STATUS)
  ) %>% 
  # FIX: include STATUS in the comparison -- it was computed above but was
  # omitted from the pivot, so its metrics were never reported
  pivot_longer(c(lasso, xgboost, RF, GAM, ensemble, LC, STATUS)) %>% 
  group_by(name) %>% 
  reframe(calc_metrics(outcome, value)) %>% 
  arrange(-auc) %>% 
  mutate_all(\(x) digits(x, 3))

Additive Effects

Code
# Predictor columns produced by the recipe that also exist, untransformed,
# in the raw training data.
features <- intersect(
  base_rec %>% prep(data_train) %>% bake(new_data = NULL, all_predictors()) %>% colnames(),
  colnames(data_train)
)


#' Plot the additive (partial) effect of one feature for both the lasso and
#' xgboost models on a common scale.
#'
#' Numeric features with more than 6 distinct observed values are drawn as
#' effect curves; everything else as dodged bars. A rug, weighted by the
#' observation count n, marks where the data lie. The plot is print()ed so
#' the function works inside walk().
plot_multi_effects <- function(var){
  var <- rlang::ensym(var)
  df <- bind_rows(
    lasso = get_additive_effects(var, model = lasso_fit, data = data_train),
    xgboost = get_additive_effects(var, model = bt_fit, data = data_train), 
    .id = "model"
  ) %>% 
    rename(x = !!var, y = eta)
  
  rug_data <- df %>% filter(!is.na(n)) 
  n_bks <- nrow(rug_data)
  categorical <- is.character(df$x) | is.factor(df$x)
  if(n_bks > 6 & !categorical){
    # numeric feature: one effect curve per model
    plt <- ggplot(df) + 
      geom_hline(yintercept = 0, color = "orange") +
      geom_line(aes(x, y, color = model))  
  } else {
    # categorical (or few-valued) feature: dodged bars per model
    rug_data <- rug_data %>% mutate(x = as.factor(x))
    plt <- ggplot(rug_data) + 
      geom_hline(yintercept = 0, color = "orange") +
      geom_col(aes(x, y, fill = model), width = 1/3, 
               position = "dodge") + 
      # FIX: spell out `labels` (the original `label =` only worked via
      # partial argument matching)
      scale_x_discrete(labels = scales::label_wrap(15)) 
  }

  print(
    plt + 
      geom_rug(data = rug_data %>% distinct(x, n), 
               aes(x, linewidth = n), 
               show.legend = FALSE,
               sides = "b", alpha = .25) + 
      labs(x = as.character(var), y = "partial effect") +
      scale_y_continuous(breaks = seq(-10, 10, by = .25)) +
      scale_color_brewer(type = "qual", palette = "Dark2") +
      scale_fill_brewer(type = "qual", palette = "Dark2") +
      coord_cartesian(ylim = c(-1, 1)) + 
      theme(
        legend.position = c(0.02, 0.98),
        legend.justification = c("left", "top"),
        legend.title = element_blank()
      )
  )
}


walk(features, plot_multi_effects)

Feature Importance

The feature importance metric is the mean absolute partial effect, weighted by the number of observations at each feature value (analogous to SHAP importance).

Code
# Mean Absolute Effect: feature importance is the n-weighted mean of the
# absolute partial effect (analogous to SHAP importance). Returns one row
# per model with an `Importance` column.
feature_importance <- function(var){
  var <- rlang::ensym(var)
  # FIX: return the summary directly; the original assigned it to an unused
  # local `df`, so the result was only returned invisibly.
  bind_rows(
    lasso = get_additive_effects(var, model = lasso_fit, data = data_train),
    xgboost = get_additive_effects(var, model = bt_fit, data = data_train), 
    .id = "model"
  ) %>% 
    rename(x = !!var, y = eta) %>% 
    filter(!is.na(n)) %>%   # keep only observed values (those carry n)
    group_by(model) %>% 
    summarize(Importance = sum(n * abs(y)) / sum(n))
}

# Importance of every feature, by model, as a dodged bar chart; features are
# ordered by their mean importance across the two models.
map(set_names(features), feature_importance) %>% 
  bind_rows(.id = "Variable") %>% 
  mutate(Variable = fct_reorder(Variable, Importance, .fun = "mean")) %>% 
  ggplot(aes(x=Importance, y=Variable, fill = model)) + 
  geom_col(position = "dodge") + 
  scale_fill_brewer(type = "qual", palette = "Dark2")

Center Level Performance

This shows the predictive bias: test survival - predicted survival, by listing center. The top centers have the largest volume (n). Removed centers with \(n \le 2\).

As an example, the third center from the top (19468) has a bias of about -0.10. This means that the actual survival rate in this center was about 10 percentage points lower than the models predicted.

While I’m not too concerned about the bias in the low volume centers (bottom on plot), there does appear to be modest unaccounted center effects.

Code
#: per-center test performance: volume n, survivor count n0, actual survival
#: rate p, mean predicted survival, and calc_metrics() -- for each model
test_perf <- data_test %>% 
  transmute(
    WL_ID_CODE, 
    LISTING_CTR_CODE,
    REGION,
    outcome = factor(outcome, c(1,0)), 
    lasso =  1-predict(cv_fit, X_test_lasso, type = "response", 
                           s = "lambda.min")[,1], 
    xgboost =  predict(bt_fit, data_test, type = "prob")$.pred_1, 
    RF =  predict(rf_fit, data_test, type = "prob")$.pred_1,
    GAM = 1-predict(fit_gam, data_test_gam, type = "response") %>% as.numeric,
    ensemble = (lasso + xgboost + RF + GAM) / 4  # FIX: removed trailing comma
  ) %>% 
  pivot_longer(c(lasso, xgboost, RF, GAM, ensemble)) %>% 
  mutate(value = 1-value) %>%  # convert to survival predictions
  group_by(name, LISTING_CTR_CODE, REGION) %>% 
  reframe(n = n(), n0 = sum(outcome == 0), p = mean(outcome == 0),
          mean_pred = mean(value), 
          calc_metrics(outcome, 1-value))

# Per-center predictive bias (actual - predicted survival), one dodged bar
# per model; centers with n <= 2 are removed and the rest ordered by volume.
test_perf %>% 
  filter(n > 2) %>% 
  mutate(
    bias = p - mean_pred,
    n_diff = n*bias,  # bias in units of listings (not used in this plot)
    LISTING_CTR_CODE = str_c(LISTING_CTR_CODE, ": n=", n),
    LISTING_CTR_CODE = fct_reorder(LISTING_CTR_CODE, n) #abs(n)
  ) %>% 
  ggplot(aes(x = bias, y = LISTING_CTR_CODE, fill = name)) + 
  geom_col(position = "dodge") + 
  # facet_wrap(~REGION, scales = "free_y", drop=TRUE) +
  labs(fill = "model")

Code
# Center-level calibration of the ensemble: mean predicted survival vs the
# actual survival rate for centers with n > 10; point size = center volume.
test_perf %>% 
  filter(name == "ensemble") %>% 
  filter(n > 10) %>% 
  mutate(
    bias = p - mean_pred,  # computed but not used in this plot
    LISTING_CTR_CODE = fct_reorder(LISTING_CTR_CODE, n)
  ) %>% 
  ggplot(aes(x = mean_pred, y = p)) + geom_abline() + 
  geom_point(aes(size=n), shape=1) +
  scale_size_area() + 
  scale_x_continuous(breaks = seq(0, 1, by = .025)) + 
  scale_y_continuous(breaks = seq(0, 1, by = .025)) +
  labs(x = "avg survival prediction", y = "actual survival rate")